#!/usr/bin/env python3

"""
Generate syscall tables from the Linux source. Used to generate
manticore/platforms/linux_syscalls.py.

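The syscall tables and headers are fetched from git.kernel.org, so network
access is required. The aarch64 table is produced by running the C
preprocessor (`cpp`) over the kernel's unistd headers, so `cpp` must be on PATH.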

Usage:

    ./extract_syscalls.py [--linux_version linux_version] linux_syscalls.py

"""

import argparse
import os
import re
import subprocess
import sys
import tempfile
from urllib.request import urlopen

# Raw-file URL template for the torvalds/linux.git tree on git.kernel.org.
# Formatted as BASE_URL.format(<path inside the tree>, <version>); the version
# is used as a tag name after a literal "v" (e.g. "4.11" -> refs/tags/v4.11).
BASE_URL = (
    "https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/plain/{}?id=refs/tags/v{}"
)

# Use an associative list rather than a dict to get deterministic output.
# Each entry is (Manticore arch name, path of the kernel's syscall .tbl file);
# the Manticore arch name becomes the variable name of the generated table.
ARCH_TABLES = [
    ("i386", "arch/x86/entry/syscalls/syscall_32.tbl"),
    ("amd64", "arch/x86/entry/syscalls/syscall_64.tbl"),
    ("armv7", "arch/arm/tools/syscall.tbl"),
]

# Header paths for architectures whose syscall numbers are extracted from the
# generic unistd.h via cpp (see UNISTD) instead of a .tbl file.
# "{}" is filled with the Linux arch name.
BITSPERLONG_HDR = "arch/{}/include/uapi/asm/bitsperlong.h"
ARCH_UNISTD_HDR = "arch/{}/include/uapi/asm/unistd.h"
UNISTD_HDR = "include/uapi/asm-generic/unistd.h"

# Format: Manticore arch, Linux arch.
# XXX: Code that uses this might need to be tweaked for other architectures to
# work properly.
UNISTD = [("aarch64", "arm64")]

# Base number for the ARM-private syscalls below; presumably mirrors the
# kernel's __ARM_NR_BASE (EABI value) -- TODO confirm.
__ARM_NR_BASE = 0xF0000
# Extra (syscall name, number) pairs appended to a generated table, keyed by
# Manticore arch name. Presumably these are syscalls the kernel's table file
# does not list -- verify when changing the kernel version.
ADDITIONAL_SYSCALLS = {
    "armv7": [
        ("sys_ARM_NR_breakpoint", __ARM_NR_BASE + 1),
        ("sys_ARM_NR_cacheflush", __ARM_NR_BASE + 2),
        ("sys_ARM_NR_usr26", __ARM_NR_BASE + 3),
        ("sys_ARM_NR_usr32", __ARM_NR_BASE + 4),
        ("sys_ARM_NR_set_tls", __ARM_NR_BASE + 5),
    ]
}


def open_url(url):
    """Open *url* and return the response object, exiting with status 1 on failure.

    ``urlopen`` reports failures by raising ``URLError`` (``HTTPError`` for
    HTTP error statuses) rather than returning a non-2xx response, so those
    exceptions are caught here and turned into a readable message instead of
    a traceback. The explicit status check is kept as a safety net for any
    non-2xx response that is returned normally.

    Args:
        url: The URL to fetch.

    Returns:
        The open response; the caller is responsible for closing it.
    """
    from urllib.error import URLError

    def fail(reason):
        sys.stderr.write("Failed retrieving file; check version and connection.\n")
        sys.stderr.write(f"Url: {url}\n")
        sys.stderr.write(f"Reason: {reason}\n")
        sys.exit(1)

    try:
        res = urlopen(url)
    except URLError as err:  # HTTPError is a subclass of URLError
        fail(err)

    # Non-HTTP handlers (e.g. file://) report no status, so only check when
    # one is available.
    status = getattr(res, "status", None)
    if status is not None and status // 100 != 2:
        res.close()
        fail(f"HTTP status {status}")
    return res


def write_without_includes(f, res):
    """Copy the lines of *res* into *f*, dropping ``#include`` directives.

    Each line read from *res* (an iterable of ``bytes`` lines, such as an HTTP
    response) is decoded as UTF-8, stripped of surrounding whitespace, and
    written to the text file *f* with a trailing newline. Include directives
    are removed -- presumably so that ``cpp`` only sees the fetched kernel
    headers rather than resolving includes against the host system. The match
    accepts whitespace after ``#`` (``# include``) and also ``#include_next``.

    Args:
        f: Writable text file receiving the filtered lines.
        res: Binary, line-iterable source.
    """
    for raw in res:
        line = raw.decode().strip()
        if re.match(r"#\s*include", line):
            continue
        f.write(line + "\n")


# Marker emitted for every syscall once cpp expands our __SYSCALL definition,
# and the regex that recovers (number, entry point) from cpp's output.
_SYSCALL_DEFINE = "#define __SYSCALL(nr, sym) SYSCALL: nr sym"
_SYSCALL_RX = re.compile(r"SYSCALL: (\d+) ([a-z_0-9]+)")


def _table_entries(arch, path, version):
    """Return output lines for the table of *arch*, parsed from a kernel .tbl file.

    Every non-comment row with at least four columns (number, ABI, name, entry
    point) becomes ``    <number>: "<entry point>",``. Any ADDITIONAL_SYSCALLS
    registered for *arch* are appended after the parsed rows.
    """
    lines = [f"{arch} = {{\n"]
    with open_url(BASE_URL.format(path, version)) as res:
        for raw in res:
            line = raw.decode().strip()
            if line.startswith("#"):
                continue
            columns = line.split()
            if len(columns) < 4:
                # Rows without an entry-point column are skipped.
                continue
            num, _abi, _name, entry = columns[:4]
            lines.append(f'    {num}: "{entry}",\n')

    for entry, num in ADDITIONAL_SYSCALLS.get(arch, ()):
        lines.append(f'    {num}: "{entry}",\n')
    lines.append("}\n")
    return lines


def _unistd_entries(march, larch, version):
    """Return output lines for Manticore arch *march*, built from unistd headers via cpp.

    The Linux arch *larch*'s bitsperlong.h and unistd.h, followed by the
    generic unistd.h, are concatenated (without #include lines) after a
    __SYSCALL definition that expands each syscall into a "SYSCALL: nr sym"
    marker. The headers are then run through ``cpp -E``, and the markers are
    parsed back out. Entries mapped to sys_ni_syscall are omitted.
    Exits with status 1 if cpp fails.
    """
    header_urls = [
        BASE_URL.format(BITSPERLONG_HDR.format(larch), version),
        BASE_URL.format(ARCH_UNISTD_HDR.format(larch), version),
        BASE_URL.format(UNISTD_HDR, version),
    ]

    fd, tmp_path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "w") as tmp:
            # The order is important here for CPP to work correctly: our
            # __SYSCALL must be defined first, and the arch headers must come
            # before the generic one (presumably so the macros they define are
            # visible to it).
            tmp.write(_SYSCALL_DEFINE + "\n")
            for url in header_urls:
                with open_url(url) as res:
                    write_without_includes(tmp, res)

        process = subprocess.run(
            ["cpp", "-E", tmp_path], stdout=subprocess.PIPE, encoding="utf-8"
        )
    finally:
        os.remove(tmp_path)

    if process.returncode != 0:
        sys.stderr.write(
            f"cpp failed (exit code {process.returncode}) on headers for {larch}.\n"
        )
        sys.exit(1)

    lines = [f"{march} = {{\n"]
    for line in process.stdout.split("\n"):
        m = _SYSCALL_RX.search(line)
        if m is None:
            continue
        num, entry = m.groups()
        if entry != "sys_ni_syscall":  # not implemented syscall
            lines.append(f'    {num}: "{entry}",\n')

    for entry, num in ADDITIONAL_SYSCALLS.get(march, ()):
        lines.append(f'    {num}: "{entry}",\n')
    lines.append("}\n")
    return lines


def main():
    """Parse arguments, build every syscall table, and write the output module."""
    parser = argparse.ArgumentParser(description="Generate syscall tables")
    parser.add_argument("output", help="Python output to generate tables")
    parser.add_argument(
        "--linux_version",
        help="Linux kernel version to use (release tag without the leading 'v')",
        default="4.11",
    )
    args = parser.parse_args()
    version = args.linux_version

    chunks = [
        "#\n#\n# AUTOGENERATED, DO NOT EDIT\n#\n\n",
        f'LINUX_KERNEL_VERSION = "{version}"\n\n',
    ]
    for arch, path in ARCH_TABLES:
        chunks.extend(_table_entries(arch, path, version))
    for march, larch in UNISTD:
        chunks.extend(_unistd_entries(march, larch, version))

    # Write only after everything was fetched and preprocessed, so a network
    # or cpp failure leaves any existing output file untouched.
    with open(args.output, "w", encoding="utf-8") as output:
        output.writelines(chunks)


if __name__ == "__main__":
    main()
